"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
This file implements a self-debugging agent that generates and validates JAX code
from PyTorch code. The agent uses a language model to perform code conversion,
generate test cases, and debug the generated code if it fails.

The process involves the following steps for each pair of PyTorch and JAX files:
1.  **Code Generation**: The agent converts a PyTorch file to JAX code using a
    language model. It includes retries to fix syntax errors.
2.  **Test Case Generation**: A pytest-compatible test case is generated by an
    LLM to compare the behavior of the original PyTorch code and the newly
    generated JAX code.
3.  **Test Execution**: The generated test case is run to check for correctness.
    The output, including passed and failed tests, is captured.
4.  **Self-Debugging**: If the tests fail, the agent uses a debugging prompt
    and the test output to have the language model fix the JAX code and/or the
    test case. This step also includes multiple retries.
5.  **Result Logging**: The results of the tests and debugging attempts are
    logged, and overall accuracy metrics are calculated.

Overall Accuracy Metrics:
-   **Test Case Accuracy**: The percentage of individual test cases that passed
    across all generated tests.
-   **File Accuracy**: The percentage of files for which all generated test
    cases passed.

Example Invocation:
To debug all the failed cases from the entire folder, run `python self_debugging_agent.py`. Make sure that the PyTorch
files are present in the `pytorch_path` folder.
To debug a single file, run `python self_debugging_agent.py --module_name <file_name>.py`. Make sure that the PyTorch
file is present in the `pytorch_path` folder.

Ensure that the paths for PyTorch and JAX code directories are correctly set
in the script arguments.

Relevant Files:
-   `prompt_debugging.py`: Contains prompts for the language model to
    perform code debugging.
-   `code_evaluation_agent/prompt_code_evaluation.py`: Contains prompts for
    generating initial test cases.
-   `code_evaluation_agent/utils.py`: Provides utilities like `get_last_defined_module`
    and `run_pytest_capture_output`.
-   `code_generation_agent/llm_agent.py`: The `GeminiAgent` class for language
    model interaction.
-   `orchestration_agent/utils.py`: Contains `parse_python_code` for extracting
    code from LLM responses.
-   `utils.py`: Provides utility functions like `save_in_file_and_check_code_syntax`
    and `parse_json_response`.
"""

import argparse
import json
import logging
import os

from MaxText.experimental.agent.code_generation_agent.llm_agent import GeminiAgent
from MaxText.experimental.agent.code_evaluation_agent.prompt_code_evaluation import CodeEvaluation
from MaxText.experimental.agent.code_evaluation_agent.utils import get_last_defined_module, run_pytest_capture_output
from MaxText.experimental.agent.orchestration_agent.utils import parse_python_code
from MaxText.experimental.agent.code_generation_agent.llm_code_generation import convert_code_from_torch_to_jax
from MaxText.experimental.agent.self_debugging_agent.utils import save_in_file_and_check_code_syntax, parse_json_response, smartly_copy_code
from MaxText.experimental.agent.self_debugging_agent.prompt_debugging import CodeDebugging

# Configure logging at import time: timestamped, level-padded messages at
# INFO and above.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def get_file_pairs(module_name, pytorch_path, jax_path):
  """Builds matching lists of PyTorch source paths and JAX destination paths.

  Lists the ``.py`` files found directly inside ``pytorch_path`` and, for each
  one, derives the path of a file with the same name under ``jax_path``. Paths
  are built with ``os.path.join`` so the directories may be given with or
  without a trailing separator.

  Args:
      module_name: Optional iterable of file names (for example
          ``["model.py"]``) to restrict the result to. ``None`` selects every
          ``.py`` file in ``pytorch_path``.
      pytorch_path: The path to the directory containing PyTorch files.
      jax_path: The path to the directory where JAX files will be stored.

  Returns:
      A tuple containing two lists of strings, in the same (sorted by file
      name) order:
          - The first list contains the full paths to the PyTorch files.
          - The second list contains the corresponding full paths for the JAX files.

  Raises:
      AssertionError: If ``module_name`` names a file that does not exist in
          ``pytorch_path``.
  """
  # Sorted so that the processing order does not depend on the file system.
  pytorch_files = sorted(name for name in os.listdir(pytorch_path) if name.endswith(".py"))
  if module_name is not None:
    # Materialize once: the iterable is walked for validation and for filtering.
    requested = list(module_name)
    assert all(
        os.path.exists(os.path.join(pytorch_path, m_name)) for m_name in requested
    ), f"{module_name} should be present at {pytorch_path}"
    wanted = set(requested)
    pytorch_files = [name for name in pytorch_files if name in wanted]
  pytorch_paths = [os.path.join(pytorch_path, name) for name in pytorch_files]
  jax_paths = [os.path.join(jax_path, name) for name in pytorch_files]
  return pytorch_paths, jax_paths


def generate_test_case(python_file, entry_module, python_code, jax_code, jax_file, test_file_path):
  """Generates a pytest test case file for a given PyTorch and JAX code pair.

  Fills the ``CodeEvaluation["TESTCASE"]`` prompt with both implementations
  (each preceded by the import line a test would use for it), asks the
  evaluation LLM agent for a test, extracts the Python code from the response
  and writes it to ``test_file_path`` behind a small ``sys.path`` bootstrap
  header.

  Args:
      python_file: The path to the original PyTorch code file.
      entry_module: The name of the main module (function or class) to be tested.
      python_code: The content of the PyTorch code file.
      jax_code: The content of the JAX code file.
      jax_file: The path where the JAX code file is or will be saved.
      test_file_path: The path where the generated test case should be saved.

  Returns:
      The generated test case code as a string (without the bootstrap header
      that is written to the file).
  """
  prompt = CodeEvaluation["TESTCASE"]
  # The dotted module path is the file path minus its first component and
  # minus the ".py" suffix.
  python_module = ".".join(python_file.split(os.path.sep)[1:]).removesuffix(".py")
  jax_module = ".".join(jax_file.split(os.path.sep)[1:]).removesuffix(".py")
  python_code = f"from {python_module} import {entry_module}\n\n{python_code}"
  jax_code = f"from {jax_module} import {entry_module}\n\n{jax_code}"
  prompt = prompt.replace("<module.path.to.pytorch_code>", python_code)
  prompt = prompt.replace("<module.path.to.jax_code>", jax_code)
  prompt = prompt.replace("<function_or_class_to_call>", entry_module)
  # Lazily create the evaluation agent if the module-level instance is absent.
  # This must be the agent that is called on the next line.
  if "llm_evaluation_agent" not in globals():
    globals()["llm_evaluation_agent"] = GeminiAgent(CodeEvaluation["SystemPrompt"])
  response = llm_evaluation_agent(prompt)
  generated_code = parse_python_code(response.text)
  with open(test_file_path, "wt", encoding="utf-8") as f:
    # Bootstrap header: makes the parent of the working directory importable.
    f.write("import os,sys\nsys.path.append(os.path.abspath('..'))\n")
    f.write(generated_code)
  logger.info("Written at %s", test_file_path)
  return generated_code


def save_and_run_test_case(jax_code, test_code, jax_file, test_file_path):
  """Writes the JAX code and its test case to disk, then runs the test file.

  The test file is written with a two-line ``sys.path`` bootstrap in front of
  ``test_code``. The test file is then handed to
  ``run_pytest_capture_output`` together with the folder that contains it.

  Args:
      jax_code: The JAX code to be saved.
      test_code: The test case code to be saved.
      jax_file: The path to the file where the JAX code will be written.
      test_file_path: The path to the file where the test case will be written.

  Returns:
      Whatever ``run_pytest_capture_output`` returns. Callers in this file
      unpack it as a 5-tuple:
          - The captured output from the pytest execution.
          - The exit code of the pytest process.
          - A boolean indicating if a dependency error occurred.
          - The number of passed tests.
          - The number of failed tests.
  """
  path_bootstrap = "import os,sys\nsys.path.append(os.path.abspath('..'))\n"
  # (destination, chunks written in order) for each file to produce.
  targets = (
      (jax_file, (jax_code,)),
      (test_file_path, (path_bootstrap, test_code)),
  )
  for destination, chunks in targets:
    with open(destination, "wt", encoding="utf-8") as handle:
      for chunk in chunks:
        handle.write(chunk)
  # Split once on the last separator: [folder, file name].
  pieces = test_file_path.rsplit(os.path.sep, 1)
  return run_pytest_capture_output(pieces[-1], code_folder=pieces[0])


def code_debugging(args, python_file, jax_file, test_file_path, last_output, code_history, base_try):
  """Debugs JAX code using a language model based on test failures.

  This function takes the failed test output and provides it to a language model
  to generate a corrected JAX code or test case. It retries this process
  multiple times until the tests pass or the retry limit is reached. Every
  attempt is appended to ``code_history``.

  Args:
      args: Parsed CLI arguments; ``code_debug_error_tries`` and
          ``error_penalty`` are read here.
      python_file: The path to the PyTorch reference code.
      jax_file: The path to the JAX code file being debugged.
      test_file_path: The path to the test case file.
      last_output: The output from the last failed test run (stack trace).
      code_history: A list of dictionaries containing previous code states and
          test results. The last entry is taken as the state to debug.
      base_try: The current attempt number for debugging (currently unused).

  Returns:
      A tuple containing:
          - An integer exit code (0 for success, 1 for failure).
          - The number of passed tests from the final attempt.
          - The number of failed tests from the final attempt.
          - The updated code history list.
  """
  try:
    memory_list = []
    with open(python_file, "rt", encoding="utf-8") as f:
      python_code = f.read()
    # The caller appends the state under test right before calling, so the
    # newest entry (not the first one) matches `last_output`.
    current_state = code_history[-1]
    jax_code, test_case_code = current_state["jax_code"], current_state["test_case_code"]
    entry_module = get_last_defined_module(python_code)
    passed, failed = 0, 0
    for try_count in range(args.code_debug_error_tries):
      logger.info("Running Self Debugging for %d time for %s", try_count + 1, python_file)
      if try_count == 0:
        prompt = CodeDebugging["CODE"]
        # Strip only the trailing ".py"; a plain replace would also rewrite
        # ".py" inside path components such as ".pytorch".
        python_module = ".".join(python_file.split(os.path.sep)[1:]).removesuffix(".py")
        jax_module = ".".join(jax_file.split(os.path.sep)[1:]).removesuffix(".py")
        python_code = f"from {python_module} import {entry_module}\n\n{python_code}"
        jax_code = f"from {jax_module} import {entry_module}\n\n{jax_code}"
        prompt = prompt.replace("<module.path.to.pytorch_code>", python_code)
        prompt = prompt.replace("<module.path.to.jax_code>", jax_code)
        prompt = prompt.replace("<function_or_class_to_call>", entry_module)
        prompt = prompt.replace("<PyTorch_Code>", python_code)
        prompt = prompt.replace("<JAX_code>", jax_code)
        prompt = prompt.replace("<pytest_test_code>", test_case_code)
        prompt = prompt.replace("<stack_trace>", last_output)
      else:
        # Later rounds only send the new stack trace; the conversation so far
        # is carried in memory_list.
        prompt = CodeDebugging["FollowUpPrompt"].replace("<stack_trace>", last_output)
      memory_list.append({"role": "user", "parts": prompt})
      resp = llm_debugging_agent(memory_list)
      memory_list.append({"role": "model", "parts": resp.text})
      json_code = parse_json_response(resp.text)
      # Sentinel values (or near-empty strings) mean "keep the previous version".
      if len(json_code["jax_code"]) > 5 and json_code["jax_code"] != "NOJAXCODE":
        jax_code = json_code["jax_code"]
      if len(json_code["test_code"]) > 5 and json_code["test_code"] != "NOTESTCASE":
        test_case_code = json_code["test_code"]
      last_output, exit_code, _, passed, failed = save_and_run_test_case(
          jax_code, test_case_code, jax_file, test_file_path
      )
      code_history.append({"passed": passed, "failed": failed, "jax_code": jax_code, "test_case_code": test_case_code})
      if passed > 0 and failed == 0:
        logger.info("Code Debugger able to solve the bugs in %s in %d Try ", jax_file, try_count + 1)
        return exit_code, passed, failed, code_history
    return 1, passed, failed, code_history
  except (IOError, KeyError, AttributeError, json.JSONDecodeError) as e:
    logger.error("Exception in code debugging %s", e)
    return 1, 0, args.error_penalty, code_history


def make_code_and_debug(args, python_file, jax_file):
  """Generates JAX code from PyTorch, generates a test case, and performs
  self-debugging if the tests fail.

  This is the main function that orchestrates the code generation and
  debugging process for a single pair of PyTorch and JAX files. It
  manages multiple tries for code generation, syntax error fixes, and
  debugging attempts. If no attempt passes every test, the attempt with the
  highest pass ratio is written back to ``jax_file`` and the test file.

  Args:
      args (argparse.Namespace): CLI arguments
      python_file: The path to the PyTorch code file.
      jax_file: The path where the generated JAX code will be stored.

  Returns:
      A tuple containing the number of passed and failed test cases
      from the final successful or best-effort attempt. ``(1, 0)`` is returned
      when the LLM reports that no test case is possible, and
      ``(0, args.error_penalty)`` when generation fails or an expected
      exception is raised.
  """
  assert os.path.exists(python_file), f"python file {python_file} not exists"
  file_name = python_file.split(os.path.sep)[-1]
  # Header that every writer in this file puts in front of a test case.
  test_header = "import os,sys\nsys.path.append(os.path.abspath('..'))\n"
  try:
    # copy code if exists or generate if not
    if smartly_copy_code(
        file_name,
        base_jax_path=args.base_jax_path,
        base_testcase_path=args.base_testcase_path,
        dest_jax_path=args.jax_path,
        dest_testcase_path=args.testcase_path,
    ):
      logger.info("Copied code for %s from Single run code", file_name)
    with open(python_file, "rt", encoding="utf-8") as f:
      python_code = f.read()
    entry_module = get_last_defined_module(python_code)
    test_file_path = os.path.join(args.testcase_path, file_name)
    code_history = []
    for base_try in range(args.code_generation_tries):
      logger.info("Processing %s", python_file)
      if base_try > 0 or not os.path.exists(jax_file):
        if base_try == 0:
          logger.info("No jax file exists for %s so generating that", file_name)
        for syntax_index in range(args.code_syntax_error_tries):
          jax_code, _ = convert_code_from_torch_to_jax(python_code, [])
          if jax_code == "<NOCHANGE>":
            logger.info("No change in code Using original code")
            jax_code = python_code
            # Persist it: the tests import from jax_file, and this branch
            # skips the save-and-check helper used for converted code.
            with open(jax_file, "wt", encoding="utf-8") as f:
              f.write(jax_code)
            break
          if save_in_file_and_check_code_syntax(jax_code, jax_file)[0]:
            logger.error("It seems JAX have syntax error so regenerating %d time", syntax_index + 2)
          elif get_last_defined_module(jax_code) != entry_module:
            logger.error(
                "It seems inconsistency in %s code PyTorch have %s and JAX have %s as entry Module so regenerating",
                python_file,
                entry_module,
                get_last_defined_module(jax_code),
            )
          else:
            break
        else:
          # Every syntax retry failed; move on to the next generation attempt.
          logger.error(
              "Not able to solve the syntax error in %s in %d tries for %d times",
              jax_file,
              args.code_syntax_error_tries,
              base_try + 1,
          )
          continue
      else:
        logger.info("Checking test case for existing code for %s", file_name)
        with open(jax_file, "rt", encoding="utf-8") as f:
          jax_code = f.read()
      if base_try > 0 or not os.path.exists(test_file_path):
        test_case_code = generate_test_case(python_file, entry_module, python_code, jax_code, jax_file, test_file_path)
        if test_case_code == "NOTESTCASE":
          logger.info("Test case is not possible")
          return 1, 0
        if "<UNABLETOGENERATE>" in test_case_code:
          return 0, args.error_penalty
      else:
        with open(test_file_path, "rt", encoding="utf-8") as f:
          # Drop the header so it is not duplicated when the test is rewritten.
          test_case_code = f.read().removeprefix(test_header)
      output, _, is_dependency_error, num_passed, num_failed = run_pytest_capture_output(
          file_name, code_folder=args.testcase_path
      )
      if num_passed > 0 and num_failed == 0:
        # find the working code. no need to debug
        return num_passed, num_failed
      if is_dependency_error:
        logger.error("There are some missing dependency please check %s", output)
      # Record the state that produced `output`; code_debugging starts from it.
      code_history.append(
          {"passed": num_passed, "failed": num_failed, "jax_code": jax_code, "test_case_code": test_case_code}
      )
      exit_code, num_passed, num_failed, code_history = code_debugging(
          args, python_file, jax_file, test_file_path, output, code_history, base_try
      )
      if exit_code == 0 and num_passed > 0 and num_failed == 0:
        return num_passed, num_failed
    # no code with all test case passed find the best one
    if len(code_history) == 0:
      logger.info("Code %s have some issue LLM not able to solve", jax_file)
      return 0, args.error_penalty

    def pass_ratio(entry):
      # An attempt that ran no tests at all ranks lowest instead of dividing by zero.
      total = entry["passed"] + entry["failed"]
      return entry["passed"] / total if total > 0 else 0

    best_code = max(code_history, key=pass_ratio)
    with open(jax_file, "wt", encoding="utf-8") as f:
      f.write(best_code["jax_code"])
    with open(test_file_path, "wt", encoding="utf-8") as f:
      f.write(test_header)
      f.write(best_code["test_case_code"])
    return best_code["passed"], best_code["failed"]

  except (IOError, KeyError, AttributeError, json.JSONDecodeError) as e:
    logger.error("Exception in code generation %s", e)
    logger.error("The code file is %s", file_name)
    # Penalty in case of Exception
    return 0, args.error_penalty


def run_self_debugging_code_generation(args):
  """
  Runs the full self-debugging code generation and evaluation process.

  This function orchestrates the conversion of PyTorch code to JAX, generates a
  test case for each pair, and uses a self-debugging loop to fix any test failures.
  It logs the results for each file and calculates overall accuracy metrics
  (test case and file accuracy). When no file or no test case was counted, the
  corresponding accuracy is reported as 0 instead of raising.

  Args:
    args (argparse.Namespace): CLI arguments
  """
  total_passed, total_failed = 0, 0
  all_passed, all_failed, total_files = 0, 0, 0
  for python_file, jax_file in zip(*get_file_pairs(args.module_name, args.pytorch_path, args.jax_path)):
    num_passed, num_failed = make_code_and_debug(args, python_file, jax_file)
    if num_passed == num_failed == 0:  # when the code cannot be executed
      # Penalty in case of issue in test case and not executed
      num_failed = args.error_penalty
    logger.info("%s have %d cases passed and %d cases failed", python_file.split(os.path.sep)[-1], num_passed, num_failed)
    total_passed += num_passed
    total_failed += num_failed
    if num_passed == 0:
      all_failed += 1
    if num_failed == 0:
      all_passed += 1
    total_files += 1

  if total_files == 0:
    logger.warning("No PyTorch files found in %s", args.pytorch_path)

  # Guard both ratios: an empty folder or a zero error penalty can leave the
  # denominators at 0.
  total_cases = total_passed + total_failed
  test_case_accuracy = total_passed * 100 / total_cases if total_cases else 0.0
  file_accuracy = all_passed * 100 / total_files if total_files else 0.0

  logger.info("****** Results for Self Debugging ******")
  logger.info("%d files have all module passed %d files have all module failed", all_passed, all_failed)
  logger.info("Test case Accuracy %.2f%%", test_case_accuracy)
  logger.info("File Accuracy %.2f%%", file_accuracy)


def main():
  """Parses the command-line options and ensures the output folders exist.

  Default locations are derived from this file's directory and its parent
  directory. The ``jax_path`` and ``testcase_path`` directories are created
  if they are missing.

  Returns:
      The ``argparse.Namespace`` with the parsed options.
  """
  here = os.path.dirname(os.path.abspath(__file__))
  agent_root = os.path.dirname(here)

  # One (flag, add_argument keyword arguments) entry per option, in help order.
  option_specs = (
      (
          "--code_syntax_error_tries",
          {"type": int, "default": 5, "help": "for how many times to tries in case wrong syntax."},
      ),
      (
          "--code_debug_error_tries",
          {"type": int, "default": 5, "help": "for how many times to tries in case code failed."},
      ),
      (
          "--code_generation_tries",
          {
              "type": int,
              "default": 2,
              "help": "for how many times to tries in case code failed and debugging also failed.",
          },
      ),
      (
          "--error_penalty",
          {"type": int, "default": 10, "help": "Penalty for errors in test case generation or execution."},
      ),
      (
          "--module_name",
          {
              "type": str,
              "nargs": "+",
              "help": (
                  "Name of one or more modules to process. "
                  "If not provided, all modules in the pytorch_path will be processed."
              ),
          },
      ),
      (
          "--pytorch_path",
          {
              "type": str,
              "default": os.path.join(agent_root, "code_generation_agent/dataset/PyTorch/"),
              "help": "Path to the directory containing PyTorch files.",
          },
      ),
      (
          "--jax_path",
          {
              "type": str,
              "default": os.path.join(here, "dataset/jax_converted/"),
              "help": "Path to the directory containing JAX files.",
          },
      ),
      (
          "--testcase_path",
          {
              "type": str,
              "default": os.path.join(here, "dataset/test_cases/"),
              "help": "Path to the directory for generated test cases.",
          },
      ),
      (
          "--base_jax_path",
          {
              "type": str,
              "default": os.path.join(agent_root, "code_generation_agent/dataset/jax_converted/"),
              "help": "Base path for JAX files.",
          },
      ),
      (
          "--base_testcase_path",
          {
              "type": str,
              "default": os.path.join(agent_root, "code_generation_agent/dataset/test_cases/"),
              "help": "Base path for test cases.",
          },
      ),
  )

  cli = argparse.ArgumentParser(description="Code Evaluation Agent")
  for flag, options in option_specs:
    cli.add_argument(flag, **options)
  parsed = cli.parse_args()

  for folder in (parsed.jax_path, parsed.testcase_path):
    os.makedirs(folder, exist_ok=True)
  return parsed


# Shared LLM agents, created when the module is imported. They are defined
# after the functions that use them, which works because those functions look
# the names up only when they are called.
llm_evaluation_agent = GeminiAgent(CodeEvaluation["SystemPrompt"])  # called by generate_test_case
llm_debugging_agent = GeminiAgent(CodeDebugging["SystemPrompt"])  # called by code_debugging

if __name__ == "__main__":
  # main() parses the CLI arguments and returns the namespace that drives the run.
  run_self_debugging_code_generation(main())
